Trees





Kerry Back

A decision tree

Prediction in each cell is the plurality class (for classification) or the cell mean (for regression).
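
As a minimal sketch of this rule, the snippet below computes both kinds of prediction with pandas. The `toy` data frame, its cell labels, and its values are made up for illustration; in a fitted tree, the cells are the leaves.

```python
import pandas as pd

# Hypothetical toy data: each observation has already been assigned to a cell (leaf).
toy = pd.DataFrame({
    "cell":  ["A", "A", "A", "B", "B", "B", "B"],
    "label": ["High", "High", "Low", "Low", "Low", "Medium", "High"],
    "value": [0.10, 0.06, -0.04, -0.02, -0.06, 0.01, 0.03],
})

# Classification: predict the most common class in each cell.
plurality = toy.groupby("cell")["label"].agg(lambda s: s.value_counts().idxmax())

# Regression: predict the mean of the target in each cell.
cell_mean = toy.groupby("cell")["value"].mean()

print(plurality)
print(cell_mean)
```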

Another example

Splitting criterion for classification

  • In each cell, the prediction is the class with the most observations.
  • Each observation belonging to any other class is an error.
  • Splits are chosen to make the cells as “pure” as possible.
  • Perfect purity means each cell contains only one class
    \(\Rightarrow\) no errors on the training data.
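
The idea of errors and purity can be sketched as follows. The class counts are hypothetical, and the helper names `cell_errors` and `gini` are introduced here for illustration. Note that scikit-learn's `DecisionTreeClassifier` uses Gini impurity by default rather than the raw error count; both are zero for a pure cell.

```python
import numpy as np

def cell_errors(counts):
    """Number of misclassified observations when predicting the plurality class."""
    counts = np.asarray(counts)
    return int(counts.sum() - counts.max())

def gini(counts):
    """Gini impurity of a cell: 1 minus the sum of squared class shares."""
    counts = np.asarray(counts, dtype=float)
    shares = counts / counts.sum()
    return float(1.0 - np.sum(shares ** 2))

mixed = [30, 30, 40]   # hypothetical class counts in an impure cell
pure = [0, 0, 100]     # hypothetical class counts in a pure cell

print(cell_errors(mixed), gini(mixed))
print(cell_errors(pure), gini(pure))
```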

Splitting criterion for regression

  • In each cell, the prediction is the cell mean.
  • The usual criterion is to minimize the sum of squared errors.
  • Because squared errors weight large deviations heavily, the algorithm tends to find splits that separate outliers into their own cells.
  • To avoid dependence on outliers,
    • minimize the sum of absolute errors instead, or
    • choose a target variable that does not have outliers.
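
A minimal sketch of the squared-error criterion on a single feature, assuming a brute-force scan over thresholds (not scikit-learn's actual implementation). The data and the helper names `sse` and `best_split` are hypothetical. The last observation is an outlier, and the best split isolates it in its own cell.

```python
import numpy as np

def sse(y):
    """Sum of squared errors around the cell mean."""
    return float(np.sum((y - y.mean()) ** 2))

def best_split(x, y):
    """Scan midpoints between sorted x values; return (threshold, total SSE) of the best split."""
    order = np.argsort(x)
    x, y = x[order], y[order]
    best = (None, np.inf)
    for i in range(1, len(x)):
        if x[i] == x[i - 1]:
            continue  # no threshold separates equal x values
        threshold = (x[i] + x[i - 1]) / 2
        total = sse(y[:i]) + sse(y[i:])
        if total < best[1]:
            best = (threshold, total)
    return best

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = np.array([0.1, 0.0, -0.1, 0.1, 0.0, 5.0])  # hypothetical data; last value is an outlier

threshold, total = best_split(x, y)
print(threshold, total)
```

In recent versions of scikit-learn, passing `criterion="absolute_error"` to `DecisionTreeRegressor` switches the criterion to the sum of absolute errors.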

Example: ROEQ and MOM12M in 2021-12


Get data from the SQL database as before

Fit a classification tree

import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Sort returns into terciles: 0 = low, 1 = middle, 2 = high
data['class'] = pd.qcut(data.ret, 3, labels=(0, 1, 2))
X = data[["roeq", "mom12m"]]
y = data["class"]

model = DecisionTreeClassifier(
  max_depth=2, 
  random_state=0
)
model.fit(X, y)
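
To see what `pd.qcut` does to the returns, here is a minimal sketch on made-up numbers. The `ret_demo` values are hypothetical and do not come from the database.

```python
import pandas as pd

# Hypothetical returns, sorted only to make the result easy to read.
ret_demo = pd.Series([-0.10, -0.05, -0.01, 0.00, 0.02, 0.03, 0.08, 0.12, 0.30])

# Cut at the 1/3 and 2/3 quantiles and label the bins 0 (low), 1 (middle), 2 (high).
cls_demo = pd.qcut(ret_demo, 3, labels=[0, 1, 2])

print(cls_demo.tolist())
print(cls_demo.value_counts().sort_index())
```

Each class holds about one third of the observations by construction, so a classifier that always predicts one class is right about one third of the time.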

View the classification tree

from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

# Label the splits with the feature names instead of x[0], x[1]
plot_tree(model, feature_names=list(X.columns))
plt.show()

Confusion matrix

from sklearn.metrics import ConfusionMatrixDisplay
ConfusionMatrixDisplay.from_estimator(model, X=X, y=y)
plt.show()
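
The same information is available as numbers. A minimal sketch, using synthetic data in place of the course data (`X_demo`, `y_demo`, and `demo` are hypothetical names):

```python
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the features and the three return classes.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(300, 2))
y_demo = rng.integers(0, 3, size=300)

demo = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_demo, y_demo)

# Rows are true classes, columns are predicted classes.
cm = confusion_matrix(y_demo, demo.predict(X_demo))

# Correct predictions lie on the diagonal.
accuracy = np.trace(cm) / cm.sum()
print(cm)
print(accuracy)
```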

Predicted class probabilities

  • Three of the four leaves have a plurality of High (class 2), so all observations in those leaves get a prediction of High.
  • But those three leaves are not equally confident.
  • The fraction of Highs in a leaf is the estimated probability that an observation in the leaf is High. Across the four leaves, the probabilities are
    • 53/69 = 77%
    • 315/695 = 45%
    • 409/1664 = 25% (the leaf not predicted High, since a plurality among three classes requires more than one third)
    • 70/114 = 61%
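
These leaf fractions are what `predict_proba` returns for a fitted tree. A minimal sketch on synthetic data (`X_demo`, `y_demo`, and `demo` are hypothetical stand-ins; with the model from the slides, the call would be `model.predict_proba(X)`):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the features and the three return classes.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(300, 2))
y_demo = rng.integers(0, 3, size=300)

demo = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_demo, y_demo)

# One row per observation, one column per class, ordered as in demo.classes_.
proba = demo.predict_proba(X_demo)

# apply() returns the index of the leaf each observation falls into.
leaf = demo.apply(X_demo)

# The probability of class 2 for the first observation equals
# the fraction of class 2 among training observations in its leaf.
in_leaf = leaf == leaf[0]
share_high = np.mean(y_demo[in_leaf] == 2)
print(proba[0, 2], share_high)
```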

Fit a regression tree

from sklearn.tree import DecisionTreeRegressor

X = data[["roeq", "mom12m"]]
y = data["ret"]

model = DecisionTreeRegressor(
  max_depth=2,
  random_state=0
)
model.fit(X, y)

View the regression tree

# Label the splits with the feature names instead of x[0], x[1]
plot_tree(model, feature_names=list(X.columns))
plt.show()

Predicting ranks

data['rnk'] = data.ret.rank(pct=True)

X = data[["roeq", "mom12m"]]
y = data["rnk"]

model = DecisionTreeRegressor(
  max_depth=2,
  random_state=0
)
model.fit(X, y)
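
Percentile ranks are one way to get a target without outliers. A minimal sketch on made-up returns (the `ret_demo` values are hypothetical):

```python
import pandas as pd

# Hypothetical returns with one extreme loss and one extreme gain.
ret_demo = pd.Series([-0.50, 0.01, 0.02, 0.03, 3.00])

# rank(pct=True) replaces each value with its rank divided by the number of observations.
rnk_demo = ret_demo.rank(pct=True)
print(rnk_demo.tolist())
```

The ranks are evenly spaced between 0 and 1 regardless of how extreme the original returns are, so no single observation dominates the sum of squared errors.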

View the regression tree for ranks

# Label the splits with the feature names instead of x[0], x[1]
plot_tree(model, feature_names=list(X.columns))
plt.show()

Predicting numerical classes

X = data[["roeq", "mom12m"]]
y = data["class"].astype(int)  # treat the class labels 0, 1, 2 as numbers

model = DecisionTreeRegressor(
  max_depth=2,
  random_state=0
)
model.fit(X, y)

View the regression tree for classes

# Label the splits with the feature names instead of x[0], x[1]
plot_tree(model, feature_names=list(X.columns))
plt.show()